{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a99fc380-490f-417f-8c84-dd223295b337",
   "metadata": {},
   "outputs": [],
   "source": [
    "import altair as alt\n",
    "import panel as pn\n",
    "import pandas as pd\n",
    "\n",
    "from sklearn.cluster import KMeans\n",
    "\n",
    "# Load the Tabulator and Vega extensions; use the Material design and template.\n",
    "pn.extension(\n",
    "    'tabulator',\n",
    "    'vega',\n",
    "    design='material',\n",
    "    template='material',\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d3fef85-1cc7-4f55-bc05-418f05c14a71",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bf759c7-117f-43bc-9f4a-95e8dfecfb8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the Penguins CSV through Panel's cache wrapper and drop rows with missing values.\n",
    "PENGUINS_URL = 'https://datasets.holoviz.org/penguins/v1/penguins.csv'\n",
    "cached_read_csv = pn.cache(pd.read_csv)\n",
    "penguins = cached_read_csv(PENGUINS_URL).dropna()\n",
    "\n",
    "# The columns at positions 2-5 are the ones used for clustering and offered for plotting.\n",
    "cols = penguins.columns[2:6].tolist()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f9fd6ea-1b56-4d5f-85ab-6f3a3cc9f559",
   "metadata": {},
   "source": [
    "## Define application"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7061fb2e-3c93-4144-a321-bc0e17e4539d",
   "metadata": {},
   "outputs": [],
   "source": [
    "@pn.cache\n",
    "def get_clusters(n_clusters, random_state=0):\n",
    "    \"\"\"Fit k-means on the selected penguin columns and label every row.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n_clusters : int\n",
    "        Number of clusters to fit.\n",
    "    random_state : int, default 0\n",
    "        Seed for the k-means initialisation. A fixed seed keeps the labels\n",
    "        the same across kernel restarts and server processes instead of\n",
    "        depending on a random start.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pandas.DataFrame\n",
    "        Copy of ``penguins`` with an added string column ``labels`` holding\n",
    "        the cluster assigned to each row.\n",
    "    \"\"\"\n",
    "    kmeans = KMeans(n_clusters=n_clusters, n_init='auto', random_state=random_state)\n",
    "    est = kmeans.fit(penguins[cols].values)\n",
    "    # Work on a copy so the shared ``penguins`` frame is never mutated.\n",
    "    df = penguins.copy()\n",
    "    # Labels are stored as strings rather than integers.\n",
    "    df['labels'] = est.labels_.astype('str')\n",
    "    return df\n",
    "\n",
    "@pn.cache\n",
    "def get_chart(x, y, df):\n",
    "    \"\"\"Build a layered Altair chart: data points plus one marker per label mean.\n",
    "\n",
    "    ``x`` and ``y`` are column names of ``df``. ``df`` must also provide the\n",
    "    'labels' and 'species' columns used for the shape and color encodings.\n",
    "    The module-level ``brush`` selection is attached to the point layer.\n",
    "    \"\"\"\n",
    "    # Select the column only once when both axes show the same variable.\n",
    "    center_cols = [x] if x == y else [x, y]\n",
    "    centers = df.groupby('labels')[center_cols].mean()\n",
    "\n",
    "    x_encoding = alt.X(x, scale=alt.Scale(zero=False))\n",
    "    y_encoding = alt.Y(y, scale=alt.Scale(zero=False))\n",
    "\n",
    "    scatter = alt.Chart(df).mark_point(size=100)\n",
    "    scatter = scatter.encode(x=x_encoding, y=y_encoding, shape='labels', color='species')\n",
    "    scatter = scatter.add_params(brush)\n",
    "\n",
    "    center_marks = alt.Chart(centers).mark_point(size=250, shape='cross', color='black')\n",
    "    center_marks = center_marks.encode(x=x+':Q', y=y+':Q')\n",
    "\n",
    "    layered = scatter + center_marks\n",
    "    return layered.properties(width='container', height='container')\n",
    "\n",
    "# Introductory Markdown text for the dashboard.\n",
    "intro_text = \"\"\"\n",
    "This app provides an example of **building a simple dashboard using\n",
    "Panel**.\\n\\nIt demonstrates how to take the output of **k-means\n",
    "clustering on the Penguins dataset** using scikit-learn,\n",
    "parameterizing the number of clusters and the variables to\n",
    "plot.\\n\\nThe plot and the table are linked, i.e. selecting on the plot\n",
    "will filter the data in the table.\\n\\n The **`x` marks the center** of\n",
    "the cluster.\n",
    "\"\"\"\n",
    "intro = pn.pane.Markdown(intro_text, sizing_mode='stretch_width')\n",
    "\n",
    "# Axis selectors, limited to the clustering columns.\n",
    "x = pn.widgets.Select(\n",
    "    name='x',\n",
    "    options=cols,\n",
    "    value='bill_depth_mm',\n",
    ")\n",
    "y = pn.widgets.Select(\n",
    "    name='y',\n",
    "    options=cols,\n",
    "    value='bill_length_mm',\n",
    ")\n",
    "n_clusters = pn.widgets.IntSlider(\n",
    "    name='n_clusters',\n",
    "    start=1,\n",
    "    end=5,\n",
    "    value=3,\n",
    ")\n",
    "\n",
    "# Named interval selection that get_chart adds to the point layer.\n",
    "brush = alt.selection_interval(name='brush')\n",
    "\n",
    "# Labelled frame bound to the slider widget.\n",
    "clusters = pn.bind(get_clusters, n_clusters)\n",
    "\n",
    "chart = pn.pane.Vega(\n",
    "    pn.bind(get_chart, x, y, clusters),\n",
    "    min_height=400,\n",
    "    max_height=800,\n",
    "    sizing_mode='stretch_width',\n",
    ")\n",
    "\n",
    "table = pn.widgets.Tabulator(\n",
    "    clusters,\n",
    "    pagination='remote',\n",
    "    page_size=10,\n",
    "    height=600,\n",
    "    sizing_mode='stretch_width',\n",
    ")\n",
    "\n",
    "def vega_filter(filters, df):\n",
    "    filtered = df\n",
    "    for field, drange in (filters or {}).items():\n",
    "        filtered = filtered[filtered[field].between(*drange)]\n",
    "    return filtered\n",
    "\n",
    "# Filter the table rows with the chart's brush selection.\n",
    "brush_filter = pn.bind(vega_filter, chart.selection.param.brush)\n",
    "table.add_filter(brush_filter)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23d4c182-72e5-4676-bbe0-1b461a9bcbb8",
   "metadata": {},
   "source": [
    "## Layout app"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd9eb64f-6a25-4c47-81be-0ffe62bb8620",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mark the controls servable in the 'sidebar' area and the content servable with the page title.\n",
    "sidebar = pn.Column(x, y, n_clusters).servable(area='sidebar')\n",
    "main = pn.Column(intro, chart, table).servable(title='KMeans Clustering')\n",
    "\n",
    "# The row stays the last expression so the notebook also displays the layout inline.\n",
    "pn.Row(sidebar, main, sizing_mode='stretch_both', min_height=1000)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
